# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
import functools
import itertools as it
from typing import TypeVar, Any, Union

import numpy as np

from jax._src import ad_checkpoint
from jax._src import api
from jax._src import api_util
from jax._src import callback
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src import effects
from jax._src import lax
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import numpy as jnp
from jax._src import pjit
from jax._src import shard_map as jshmap
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec as P
from jax._src.tree_util import tree_flatten
from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
                           unzip3, weakref_lru_cache, HashableWrapper, foreach)

# Backward compatibility: some downstream users implicitly rely on this import,
# and reference jax.experimental.shard_map without an explicit import.
# TODO(yashkatariya): remove this once users are migrated to jax.shard_map.
try:
  import jax.experimental.shard_map as _  # pytype: disable=import-error  # noqa: F401
except ImportError:
  pass

# Register this file for exclusion from JAX source-info summaries and from
# filtered tracebacks.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Within this module `map`/`zip` refer to JAX's `safe_map`/`safe_zip`; the
# builtins remain available as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')

## Utils

def popattr(obj, attrname):
  """Deletes attribute `attrname` from `obj` and returns the value it held."""
  popped = getattr(obj, attrname)  # raises AttributeError if it is absent
  delattr(obj, attrname)
  return popped

def setnewattr(obj, name, val):
  """Sets `obj.<name> = val`, asserting the attribute did not exist before."""
  missing = object()
  existing = getattr(obj, name, missing)
  assert existing is missing
  setattr(obj, name, val)

# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class when it is defined. The
  base flattening has no leaves and keeps `traceback_info` as static
  metadata; subclasses with runtime data override the pytree methods.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    # Register each concrete error type as a pytree node.
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    children = []  # no array leaves at this level
    return children, self.traceback_info

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload  # unused: the base flattening produces no leaves
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Returns the ErrorEffect for this error; must be overridden."""
    raise NotImplementedError


@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect marking that a computation may produce an error of `error_type`.

  `shape_dtypes` describes the payload arrays carried for that error. The
  class is a frozen (hashable) dataclass and is totally ordered by
  (error-type string, payload shapes and dtype names).
  """
  error_type: type[JaxException]
  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: ErrorEffect):
    def sort_key(effect):
      # dtype objects are not orderable, so compare their string forms.
      payload = tuple((sd.shape, str(sd.dtype)) for sd in effect.shape_dtypes)
      return str(effect.error_type), payload
    return sort_key(self) < sort_key(other)

# Add ErrorEffect to the lowerable, control-flow-allowed,
# custom-derivatives-allowed and remat-allowed effect type sets.
effects.lowerable_effects.add_type(ErrorEffect)
effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.custom_derivatives_allowed_effects.add_type(ErrorEffect)
effects.remat_allowed_effects.add_type(ErrorEffect)

class DivisionByZeroError(JaxException):
  """Error recorded when a division's divisor contains a zero."""

  def get_effect_type(self):
    # No payload arrays: only the static traceback info is carried.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    return 'division by zero'

class NaNError(JaxException):
  """Error recorded when a primitive's output contains a NaN."""

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name  # name of the primitive that produced the NaN

  def __str__(self):
    return f'nan generated by primitive: {self.prim}.'

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def tree_flatten(self):
    # Everything is static metadata; there are no array leaves.
    static = (self.traceback_info, self.prim)
    return [], static

  @classmethod
  def tree_unflatten(cls, metadata, _):
    traceback_info, prim = metadata
    return cls(traceback_info, prim)

class OOBError(JaxException):
  """Error recorded when an indexing primitive uses an out-of-bounds index.

  The payload is a length-3 int32 array holding
  (offending index, operand axis, size of that axis).
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def __str__(self):
    index, axis, size = self._payload[0], self._payload[1], self._payload[2]
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {index} is out of bounds for axis '
            f'{axis} with size {size}. ')

  def get_effect_type(self):
    return ErrorEffect(OOBError, (api.ShapeDtypeStruct((3,), np.int32),))

  def tree_flatten(self):
    # The payload is the only array leaf; everything else is static.
    leaves = [self._payload]
    static = (self.traceback_info, self.prim, self.operand_shape)
    return leaves, static

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    traceback_info, prim, operand_shape = metadata
    return cls(traceback_info, prim, operand_shape, payload[0])

class FailedCheckError(JaxException):
  """Error recorded when a `check` predicate fails.

  The format arguments (`args`/`kwargs`) are the pytree leaves, so they can
  be runtime values; the format string is static metadata.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a  # note: this shadows Exception.args
    self.kwargs = k

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + ' (`check` failed)'

  def get_effect_type(self):
    leaves = jtu.tree_leaves((self.args, self.kwargs))
    payload_types = tuple(api.ShapeDtypeStruct(leaf.shape, leaf.dtype)
                          for leaf in leaves)
    return ErrorEffect(FailedCheckError, payload_types)

  def tree_flatten(self):
    children = (self.args, self.kwargs)
    static = (self.traceback_info, self.fmt_string)
    return children, static

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    args, kwargs = payload
    traceback_info, fmt_string = metadata
    return cls(traceback_info, fmt_string, *args, **kwargs)

@dataclasses.dataclass
class BatchedError(JaxException):
  """Collects the exceptions found at each index of a batched Error value.

  `error_mapping` maps a batch index tuple to the exception at that index.
  """
  error_mapping: dict[tuple[int, ...], JaxException]

  def __post_init__(self):
    # Reuse the traceback info of the first recorded error.
    first_error = [*self.error_mapping.values()][0]
    super().__init__(first_error.traceback_info)

  def __str__(self):
    lines = []
    for idx, e in self.error_mapping.items():
      where = ", ".join(map(str, idx))
      lines.append(f'at mapped index {where}: {e}')
    return '\n'.join(lines)


# Error Value

@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through checkified computations.

  `_pred`, `_code` and `_payload` are keyed by ErrorEffect, so each effect
  (error type plus payload layout) has its own predicate, code and payload.
  Codes are the integers allocated by `next_code`; `_metadata` maps each code
  to the treedef used to rebuild the recorded JaxException from its payload.
  Instances are frozen; the update helpers return new Error values.
  """
  _pred: dict[ErrorEffect, Bool]  # whether an error of this effect is set.
  _code: dict[ErrorEffect, Int]  # code of the recorded error for the effect.
  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: dict[ErrorEffect, Payload]  # exception leaves for the effect.

  def get(self) -> str | None:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> JaxException | None:
    """Returns Python exception if error happened, None if no error happened.

    Requires concrete (non-traced) predicates and codes, since they are
    converted to Python bools. If any predicate has a non-scalar shape the
    error is treated as batched and a BatchedError is returned. Otherwise,
    among effects whose predicate is set, the one with the smallest code is
    rebuilt from its metadata treedef and payload.
    """
    # A non-empty shape on any predicate marks a batched error value.
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    else:
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect]:
          # next_code allocates increasing codes, so the smallest code is the
          # one allocated first.
          if min_code is None or code < min_code:
            min_code = code
            cur_effect = error_effect

      if cur_effect is not None:
        return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                              self._payload[cur_effect])
    return None

  def throw(self):
    # Delegates to `_check_error`, defined elsewhere in this module.
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> BatchedError | None:
    """Builds a BatchedError mapping each batch index to its exception.

    Applies the same smallest-code selection as `get_exception` at every
    index. Returns None if no predicate is set at any index.
    """
    # NOTE(review): the batch shape is taken from the first predicate only;
    # this assumes every predicate/code/payload shares that leading shape.
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:   # type: ignore
          if min_code is None or code[idx] < min_code:  # type: ignore[index]
            min_code = code[idx]   # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        # Select this batch index from every payload leaf.
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    if error_mapping:
      return BatchedError(error_mapping)
    else:
      return None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with `effect_type`'s pred/code/payload replaced.

    `metadata` is merged into (and overrides on key clashes) the existing
    code-to-treedef mapping.
    """
    new_errs = {**self._pred, **{effect_type: pred}}
    new_codes = {**self._code, **{effect_type: code}}
    new_payload = {**self._payload, **{effect_type: payload}}
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
    """Fill out Error with `effects` and jnp.ones arrays of their payloads.

    Only effects not already present are added, each with predicate False and
    code -1. `_metadata` is left unchanged. (The `effects` parameter shadows
    the module of the same name inside this method.)
    """
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    # Copy with some fields replaced (the dataclass is frozen).
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Children: the pred/code/payload dicts. Aux data: the `_metadata` dict
    # itself; the parentheses around it do not create a tuple.
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)

init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
# Each call returns the next integer (1, 2, 3, ...); these codes identify
# individual assertions and key `Error._metadata`.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4

def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Records `new_error` in `error` under a fresh code, conditioned on `pred`.

  The exception is flattened: its leaves become the payload, and its treedef
  is stored as metadata under the new code so it can be rebuilt later.
  """
  fresh_code = next_code()
  effect = new_error.get_effect_type()
  leaves, treedef = tree_flatten(new_error)
  return update_error(error, pred, fresh_code, {fresh_code: treedef}, leaves,
                      effect)

def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a new (pred, code, payload) for `effect_type` into `error`.

  The resulting predicate is the OR of the old and new ones. Where an error of
  this effect type was already set, its code and payload are kept (selected
  with lax.select); otherwise the new code and payload are used.
  """
  prev_pred = error._pred.get(effect_type, False)
  merged_pred = prev_pred | pred
  merged_code = lax.select(prev_pred, error._code.get(effect_type, -1), code)
  prev_payload = error._payload.get(effect_type)
  if prev_payload is None:
    merged_payload = payload
  else:
    keep_prev = functools.partial(lax.select, prev_pred)
    merged_payload = tree_map(keep_prev, prev_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)


## Checkify transformation for plumbing functional error values.

@lu.transformation_with_aux2
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
  """Flattens f's (error, out) result.

  Stores (output treedef, set of error effects) as auxiliary data.
  """
  err, outs = f(*invals)
  flat, treedef = jtu.tree_flatten((err, outs))
  effect_set = set(err._pred.keys())
  store.store((treedef, effect_set))
  return flat

def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged, and the error
  is passed through. Call and map primitives (those with `call_jaxpr`) are
  re-bound on a checkified version of their body. The flattened error values
  are prepended to the operands, and the error is recovered from the leading
  outputs. For map primitives the mapped error outputs are then reduced to a
  single error with `_reduce_any_error`.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # The prepended error operands are never donated.
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  # Consts are wrapped in HashableWrapper so they can be passed through
  # lu.hashable_partial below.
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(
      lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
      jaxpr, consts_, enabled_errors, err_tree)
  # After the wrapped function has run, `metadata()` returns
  # (output treedef, set of error effects).
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    # Error outputs are mapped along axis 0. The number of error outputs is
    # only known after tracing, so it is derived from the stored treedef.
    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error inputs get in_axes None (not mapped).
    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    # Collapse the mapped error outputs into a single error value.
    error = _reduce_any_error(error)
  return error, out_vals

def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> tuple[Error, list[core.Value]]:
  """Runs the checkify interpreter over a closed jaxpr, starting from `error`."""
  flat_error, error_treedef = jtu.tree_flatten(error)
  inner, consts = jaxpr.jaxpr, jaxpr.consts
  return checkify_jaxpr_flat(inner, consts, enabled_errors, error_treedef,
                             *flat_error, *args)

def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> tuple[Error, list[Any]]:
  """Evaluates `jaxpr` equation by equation while threading an Error value.

  `args` holds the flattened error leaves (described by `err_tree`) followed
  by the jaxpr's inputs. Each equation is handled by its rule in
  `error_checks`, or by `default_checkify_rule` if none is registered.
  Returns the final error and the list of jaxpr output values.
  """
  env: dict[core.Var, Any] = {}
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)

  last_used = core.last_used(jaxpr)

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  foreach(write_env, jaxpr.constvars, consts)
  foreach(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    checkify_rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    # Run the rule under the equation's source info, so that tracebacks
    # captured inside rules (e.g. by get_traceback) point at user code.
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if eqn.primitive.multiple_results:
      foreach(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
    # Drop env entries no longer needed after this equation (per last_used).
    core.clean_up_dead_vars(eqn, env, last_used)

  return error, map(read_env, jaxpr.outvars)

def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  """Like checkify_jaxpr_flat, but with consts wrapped in HashableWrapper."""
  unwrapped = tuple(wrapper.x for wrapper in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, unwrapped, enabled_errors, err_tree, *args)

@lu.transformation_with_aux2
def flatten_fun_output(f, store, *args):
  """Flattens the output pytree of `f`, storing its treedef as aux data."""
  flat_out, treedef = tree_flatten(f(*args))
  store.store(treedef)
  return flat_out


def _reduce_any_error(error: Error):
  """Collapses the leading mapped axis of every entry of `error`.

  For each effect, the index placed last by argsort of the predicates is
  selected. That is a True entry whenever any predicate is set. The pred,
  code and payload are taken at that index, and the metadata is copied over
  unchanged.
  """
  reduced = init_error
  for effect, preds in error._pred.items():
    chosen = jnp.argsort(preds)[-1]
    take = lambda x, i=chosen: x[i]
    pred, code, payload = tree_map(
        take, (preds, error._code[effect], error._payload[effect]))
    reduced = reduced._update(effect, pred, code, {}, payload)
  return reduced._replace(_metadata=error._metadata)

## check_p primitive

# Effectful primitive whose operands are a flattened Error value. When
# executed, it raises if that error is set (see check_impl). It produces no
# outputs.
check_p = core.Primitive('check')
check_p.is_effectful = lambda _: True  # type: ignore
check_p.multiple_results = True  # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
  """Pretty-prints a check equation, omitting its `err_tree` param."""
  source_note = None
  if settings.source_info:
    source_note = source_info_util.summarize(eqn.source_info)
  stack_note = None
  if settings.name_stack:
    stack_note = f'[{eqn.source_info.name_stack}]'
  shown_params = sorted((key, value) for key, value in eqn.params.items()
                        if key != "err_tree")
  pieces = [
      core.pp.text(eqn.primitive.name, annotation=stack_note),
      core.pp_kv_pairs(shown_params, context, settings),
      core.pp.text(" ") + core.pp_vars(eqn.invars, context),
  ]
  return core.pp.concat([core.pp.text("", source_note), *pieces])

core.pp_eqn_rules[check_p] = _pp_check

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised when a checkified error is detected while executing check_p."""

@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation of check_p.

  Raises JaxRuntimeError (chained from the JaxException) if the error is set.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  exc = tree_unflatten(err_tree, args).get_exception()
  if not exc:
    return []
  python_tb = exc.traceback_info.as_python_traceback()
  exc.with_traceback(traceback_util.filter_traceback(python_tb))
  raise JaxRuntimeError(str(exc)) from exc

@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Returns no outputs, plus the ErrorEffects present in the error value."""
  del debug
  error = tree_unflatten(err_tree, args)
  return [], set(error._pred.keys())

# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Raised by the check_p lowering rules when a non-debug check cannot be
# lowered (unsupported platform, or config.xla_runtime_errors disabled).
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )

def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lowers check_p to a side-effecting Python host callback (python_err).

  Debug checks lower to nothing. Non-debug checks require
  config.xla_runtime_errors; otherwise functionalization_error is raised.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.xla_runtime_errors.value:
    raise functionalization_error

  host_fn = functools.partial(python_err, err_tree)
  results, _, _ = callback.emit_python_callback(
      ctx, callback=host_fn, token=None, operands=args,
      operand_avals=list(ctx.avals_in), result_avals=list(ctx.avals_out),
      has_side_effect=True, returns_token=False)
  return results

def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering for platforms without runtime checks.

  Debug checks lower to nothing; any other check raises
  functionalization_error.
  """
  if not debug:
    raise functionalization_error
  return []

def python_err(err_tree, *args):
  """Host callback: rebuilds the Error from its leaves and hands it to
  `_check_error` (defined elsewhere in this module). Returns no results."""
  _check_error(tree_unflatten(err_tree, args))
  return []

# On TPU, non-debug checks cannot be lowered (functionalization_error). On CPU
# and GPU they lower to a Python host callback.
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')

def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Batching rule for check_p.

  Moves every operand's batch dim to the front (broadcasting unbatched
  operands to the batch size), then checks the resulting batched error.
  """
  axis_size = next(arg.shape[bdim]
                   for arg, bdim in zip(batched_args, batch_dims)
                   if bdim is not batching.not_mapped)
  moved = [batching.bdim_at_front(arg, bdim, axis_size)
           for arg, bdim in zip(batched_args, batch_dims)]
  _check_error(tree_unflatten(err_tree, moved), debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule

def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for check_p: re-binds check on the primals, ignoring tangents."""
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  no_primal_outs, no_tangent_outs = [], []
  return no_primal_outs, no_tangent_outs
ad.primitive_jvps[check_p] = check_jvp_rule

## checkify rules

# Rules are called as rule(error, enabled_errors, *in_vals, **params) and
# return (new_error, out_vals).
ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Error, Any)
# Per-primitive checkify rules. Primitives without an entry fall back to
# default_checkify_rule in checkify_jaxpr_flat.
error_checks: dict[core.Primitive, ErrorCheckRule] = {}


def get_traceback():
  """Returns the traceback stored in the current source-info context."""
  info = source_info_util.current()
  return info.traceback

def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Binds `prim`, then lets check_nans record a NaNError for its output."""
  result = prim.bind(*in_vals, **params)
  return check_nans(prim, error, enabled_errors, result), result

def check_nans(prim, error, enabled_errors, out):
  """Records a NaNError attributed to `prim` if any value in `out` is NaN.

  `out` is a sequence of arrays when `prim.multiple_results` is set, and a
  single array otherwise. Arrays with a PRNG key dtype are skipped. If
  NaNError is not enabled, `error` is returned unchanged.
  """
  if NaNError not in enabled_errors:
    return error

  def has_nan(x):
    if dtypes.issubdtype(x.dtype, dtypes.prng_key):
      return False
    return jnp.any(jnp.isnan(x))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([has_nan(x) for x in out]))
  else:
    any_nans = has_nan(out)
  return assert_func(error, any_nans, NaNError(get_traceback(), prim.name))


# All primitives which can generate a NaN.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

# Use nan_error_check (bind, then check the output for NaNs) as the checkify
# rule for each of them.
for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Checkify rule for dynamic_slice.

  A start index is out of bounds if it is negative, or if the slice starting
  there would run past the end of its axis.
  """
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)
  if OOBError not in enabled_errors:
    return error, out

  starts = jnp.array(start_indices)
  idx_dtype = starts.dtype
  dims = np.array(operand.shape, dtype=idx_dtype)
  sizes = np.array(slice_sizes, dtype=idx_dtype)
  oob_mask = (starts < 0) | (starts + sizes > dims)

  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  found_oob = jnp.any(oob_mask)
  oob_error = OOBError(get_traceback(), "dynamic_slice", operand.shape, payload)
  error = assert_func(error, found_oob, oob_error)
  return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check

def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
  """Checkify rule for dynamic_update_slice.

  A start index is out of bounds if it is negative, or if `update` placed
  there would run past the end of the corresponding operand axis.
  """
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
  if OOBError not in enabled_errors:
    return error, out

  dims = np.array(operand.shape)
  upd_dims = np.array(update.shape)
  starts = jnp.array(start_indices)
  oob_mask = (starts < 0) | (starts + upd_dims > dims)

  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  found_oob = jnp.any(oob_mask)
  oob_error = OOBError(get_traceback(), "dynamic_update_slice", operand.shape,
                       payload)
  error = assert_func(error, found_oob, oob_error)
  return error, out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check

def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Checkify rule for gather.

  The gather is bound unchanged. If OOBError is enabled, an OOBError is
  recorded when a start index for operand axis d = start_index_map[k] lies
  outside [0, operand.shape[d] - slice_sizes[d]]. The gather `mode` is not
  consulted by this check.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  # NOTE(review): treats all but the last dimension of start_indices as batch
  # dims, i.e. assumes the index vector is the trailing dimension — confirm.
  num_batch_dims = len(start_indices.shape) - 1

  # Largest valid start per indexed axis: operand size minus slice size.
  upper_bound = operand_dims[np.array(dnums.start_index_map)]
  upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
  # Add leading singleton axes so the bound broadcasts over the batch dims.
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  error = assert_func(error, jnp.any(oob_mask), OOBError(get_traceback(), "gather", operand.shape, payload))
  return error, out
error_checks[lax.gather_p] = gather_error_check

def div_error_check(error, enabled_errors, x, y):
  """Checkify rule for div.

  Records a DivisionByZeroError if any element of `y` equals zero (when
  enabled), then binds div and applies the NaN check to its result.
  """
  if DivisionByZeroError in enabled_errors:
    zero_divisor = jnp.any(jnp.equal(y, 0))
    error = assert_func(error, zero_divisor,
                        DivisionByZeroError(get_traceback()))
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check

def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Builds the OOBError payload describing the first out-of-bounds index.

  Args:
    oob_mask: boolean array with the same shape as `indices`; True where the
      index is out of bounds.
    indices: the index array that was checked.
    dims_map: maps positions along the last axis of `indices` to operand axes.
    operand_shape: shape of the indexed operand.

  Returns:
    An int32 array of shape (3,): [offending index, operand axis, axis size].
    If no entry of `oob_mask` is set, it describes position 0 instead. The
    accompanying predicate is then False, so the payload is never shown. If
    `oob_mask` is empty, zeros are returned, because argmin is undefined on
    an empty array (and nothing can be out of bounds).
  """
  if np.size(oob_mask) == 0:
    # e.g. gather/scatter with zero indices: jnp.argmin would raise here.
    return jnp.zeros((3,), dtype=np.int32)
  # argmin of the negated mask is the flat position of the first True entry.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  return jnp.array([oob_index, oob_axis, oob_axis_size], dtype=np.int32)

def scatter_oob(operand, indices, updates, dnums):
  """Computes the out-of-bounds predicate and OOBError payload for a scatter.

  Returns:
    (any_oob, payload). `any_oob` is a scalar boolean that is True if any
    entry of `indices` lies outside [0, operand_dim - window_size] for the
    operand axis it addresses. `payload` is built by `oob_payload`.
  """
  # Ref: see clamping code used in scatter_translation_rule
  # Window size along every operand axis: 1 for inserted window dims,
  # otherwise the size of the matching update window dim.
  slice_sizes = []
  pos = 0
  for i in range(len(operand.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1

  # Largest valid start for each scattered operand axis.
  upper_bound = np.array([operand.shape[i] - slice_sizes[i]
                          for i in dnums.scatter_dims_to_operand_dims],
                         np.int64)
  # Clip so the later cast to indices.dtype cannot overflow.
  upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
  # Lay the per-axis bounds along the last dimension of `indices`.
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                     (len(indices.shape) - 1,))

  lower_oob = jnp.less(indices, 0)
  upper_oob = jnp.greater(indices, upper_bound.astype(indices.dtype))
  oob_mask = jnp.logical_or(lower_oob, upper_oob)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload

def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  The scatter `prim` is bound unchanged. If OOBError is enabled, an OOBError is
  recorded when any index falls outside the range computed by `scatter_oob`.
  Independently, `check_nans` records a NaNError when NaN checks are enabled
  and the result contains a NaN. `mode` is forwarded to the primitive but is
  not consulted by the bounds check.

  Returns:
    (error, out): the updated error value and the scatter result.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates,
                                         dimension_numbers)
    oob_error = OOBError(get_traceback(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)

  # The NaN check must not depend on OOB checking being enabled. An early
  # return used to skip it, so NaN-only checkify never checked scatters.
  # check_nans itself is a no-op when NaNError is not enabled.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)

# HOP error check rules

@jtu.register_static
class ErrorEffects:
  """Holder for a set of ErrorEffects, registered as a static pytree node.

  Because it has no leaves, the wrapped `val` travels in a treedef. This is
  how jaxpr_to_checkify_jaxpr reads the effect set back out after tracing.
  """

  def __init__(self, val):
    self.val = val  # the wrapped set of ErrorEffects

@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Traces a checkified version of `jaxpr`.

  `flat_err_and_in_vals` are the abstract values of the flattened error
  (described by `err_tree`) followed by the jaxpr's inputs. Returns:
    - the checked jaxpr, whose outputs are the flattened error followed by
      the original outputs;
    - the treedef of (error, outputs);
    - the set of ErrorEffects present in the error it produces.
  Results are memoized by weakref_lru_cache.
  """

  def fun_wrapped(*invals):
    error, out = checkify_jaxpr_flat(
        jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree, *invals)
    # ErrorEffects is a static pytree node, so the effect set is carried in
    # the output treedef rather than as a traced output.
    error_effects = ErrorEffects(set(error._pred.keys()))
    return (error, out), error_effects

  debug_info = jaxpr.jaxpr.debug_info.with_unknown_names()
  checked_jaxpr, full_out_tree = pe.trace_to_jaxpr(
      fun_wrapped, None, flat_err_and_in_vals, debug_info)
  # Split the ((error, out), ErrorEffects) treedef and recover the effect set.
  out_tree, error_effects_treedef = full_out_tree.children()
  error_effects = error_effects_treedef.unflatten(()).val
  return checked_jaxpr, out_tree, error_effects

def cond_error_check(error: Error, enabled_errors, index, *ops,
                     branches, **params):
  """Checkify rule for cond: checkifies every branch against a common error.

  First collects the ErrorEffects each branch can produce. The incoming error
  is padded with placeholders for their union, so every branch takes and
  returns an error of the same structure. Each branch is then checkified
  against that merged error, and cond is re-bound with the error values
  prepended to its operands.
  """
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(core.get_aval, [*err_vals, *ops])
  def get_error_effects_from_jaxpr(jxpr):
    _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                            *in_avals)
    return effects
  effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches]
  merged_error = error._add_placeholder_effects(set().union(*effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(core.get_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(
          jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), **params)

  # we need to merge metadata across out_trees (a tuple)
  # Each branch's treedef carries its own code-to-treedef metadata (static aux
  # data), so unflatten with every branch's treedef and merge their metadata.
  err0, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, _ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return err0._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check

def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll, _split_transpose):
  """Checkify rule for scan: threads the error through the scan carry.

  The body is checkified once to discover which ErrorEffects it can produce.
  The incoming error is padded with placeholders for them, so the error
  carried into and out of each iteration has the same structure. The body is
  then checkified again against the padded error, and scan is re-bound with
  the flattened error values added to the carry.
  """

  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  # Per-iteration avals of the scanned-over inputs (mapped axis 0 removed).
  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                          err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  # The checked body takes (err_vals, consts, carry, xs). Move the const
  # binders to the front to match the operand order passed to scan_p below.
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  # The error values join the carry and are never linear.
  new_linear = (*[False] * len(err_vals), *linear)
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll, _split_transpose=_split_transpose)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check

def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Builds a checkified while-loop body that also checks the next cond.

  The new body runs `body_jaxpr` and then applies `cond_jaxpr` to the body's
  outputs, so that errors raised by the *next* cond application are captured
  inside the body. The resulting function is then checkified with `error` as
  the incoming error.

  Args:
    cond_jaxpr: the loop's cond jaxpr; its first `c_consts_num` inputs are
      cond consts, the remaining inputs are the loop carry.
    body_jaxpr: the loop's body jaxpr.
    enabled_errors: the set of enabled error categories.
    error: the incoming error; its flattened leaves become the leading inputs
      of the checked jaxpr.
    c_consts_num: number of leading cond-const inputs of `cond_jaxpr`.

  Returns:
    ``(checked_jaxpr, out_tree, error_effects)``. The checked jaxpr takes
    ``[*error_leaves, *cond_consts, *body_inputs]``.
  """
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)
  def new_body_f(*c_consts_and_vals):
    c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
    out = body_f(*vals)
    # This checks if the next cond application will error
    # NOTE(review): the cond result is unused; dce_sink presumably keeps this
    # call from being dead-code eliminated so its checks survive.
    lax.dce_sink(cond_f(*c_consts, *out))
    return out
  c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
  # Trace the combined body with the cond consts prepended to its inputs.
  jaxpr, _ = pe.trace_to_jaxpr(
      new_body_f, None,
      (*c_consts_avals, *body_jaxpr.in_avals),
      debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names())
  err_vals, err_tree = jtu.tree_flatten(error)
  err_vals = map(core.get_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
  jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
  return jaxpr, out_tree, error_effects


@weakref_lru_cache
def ignore_error_output_jaxpr(jaxpr, num_error_vals: int):
  """Constructs a checked jaxpr which does not output its error value."""
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
  return core.ClosedJaxpr(new_jaxpr, consts)

def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for ``while_p``: threads the error through the loop carry.

  The first cond application is checked before the loop. Every later cond
  application is checked inside the body (see `checkify_while_body_jaxpr`),
  so the cond jaxpr used by the loop itself drops its error outputs.

  Raises:
    ValueError: if the cond output has a non-scalar shape (batched while).
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  # First pass: only collect the error effects the body (+ next cond) can add;
  # the jaxpr built here is discarded.
  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error!
  # Adding placeholders for those effects makes the error carried into the
  # loop have the same structure as the one carried out of it.
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  # Body binders: [err, c_consts, b_consts, carry] -> [c_consts, b_consts,
  # err, carry], i.e. all consts first, then the (error-extended) carry.
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(core.get_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  # The loop's cond must only return the predicate, so drop its error outputs
  # (cond errors are checked in the body instead).
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  # Cond binders: [err, c_consts, carry] -> [c_consts, err, carry].
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # The cond consts appear twice: once as the cond's consts and once as the
  # leading consts of the checked body (which also evaluates the cond).
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check

def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings,
                     in_layouts, out_layouts,
                     donated_invars, ctx_mesh, name, inline, keep_unused,
                     compiler_options_kvs):
  """Checkify rule for ``jit_p``: passes the error leaves as extra operands.

  The checked jaxpr takes the flattened error leaves before the original
  inputs and returns the error leaves before the original outputs. Those
  extra operands get an unspecified sharding, no layout, and are never
  donated.
  """
  err_leaves, err_treedef = jtu.tree_flatten(error)
  args = [*err_leaves, *vals_in]
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
      jaxpr, enabled_errors, err_treedef, *map(core.get_aval, args))

  # The output error may carry more leaves than the input one, so count both.
  n_in_err = len(err_leaves)
  n_out_err = out_tree.num_leaves - len(out_shardings)
  unspecified = sharding_impls.UNSPECIFIED

  def prepend(fill, count, rest):
    # Tuple of `count` copies of `fill` followed by the entries of `rest`.
    return (*(fill,) * count, *rest)

  err_and_out = pjit.jit_p.bind(
      *args,
      jaxpr=checked_jaxpr,
      in_shardings=prepend(unspecified, n_in_err, in_shardings),
      out_shardings=prepend(unspecified, n_out_err, out_shardings),
      in_layouts=prepend(None, n_in_err, in_layouts),
      out_layouts=prepend(None, n_out_err, out_layouts),
      donated_invars=prepend(False, n_in_err, donated_invars),
      ctx_mesh=ctx_mesh,
      name=name,
      inline=inline,
      keep_unused=keep_unused,
      compiler_options_kvs=compiler_options_kvs,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.jit_p] = pjit_error_check


def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
  """Checkify rule for ``remat_p``: checkifies the rematerialized jaxpr.

  The flattened error leaves are passed ahead of the original inputs; all
  other remat params are forwarded unchanged.
  """
  err_leaves, err_treedef = jtu.tree_flatten(error)
  args = [*err_leaves, *vals_in]
  checked_closed, out_tree, _ = jaxpr_to_checkify_jaxpr(
      pe.close_jaxpr(jaxpr), enabled_errors, err_treedef,
      *(core.get_aval(a) for a in args))
  # remat_p takes an open jaxpr; the unpacking into `()` asserts that the
  # checked jaxpr has no consts that would otherwise be lost.
  checked_open, () = checked_closed.jaxpr, checked_closed.consts
  outs = ad_checkpoint.remat_p.bind(*args, jaxpr=checked_open, **params)
  return tree_unflatten(out_tree, outs)
error_checks[ad_checkpoint.remat_p] = remat_error_check


def shard_map_error_check(
    error: Error, enabled_errors, *vals_in,
    jaxpr: core.Jaxpr, in_specs, out_specs, **kwargs
):
  """Checkify rule for ``shard_map_p``.

  Input error leaves are passed in replicated (spec ``P()``). Inside the
  shard_map every output error leaf gets a new leading axis of size 1, and
  that axis is sharded over all mesh axes, so each shard contributes its own
  error entry to the returned error leaves.

  Raises:
    ValueError: if ``kwargs`` has no ``mesh``, or an input aval type has no
      entry in ``core.shard_aval_handlers``.
  """
  if (mesh := kwargs.get('mesh')) is None:
    raise ValueError('Mesh must be provided for shard_map with checkify.')

  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  # Replicated sharding for in errors.
  new_in_specs = (*([P()] * num_error_vals), *in_specs)
  new_vals_in = [*err_vals, *vals_in]
  in_avals = list(map(core.get_aval, new_vals_in))
  manual_axes = kwargs.get('manual_axes')
  check_vma = kwargs.get('check_vma')
  # NOTE(review): presumably maps each global aval to the per-shard aval seen
  # inside the shard_map body, according to its in-spec.
  for i, v in enumerate(in_avals):
    if not (sharder := core.shard_aval_handlers.get(type(v))):
      raise ValueError(f'Unsupported aval type: {type(v)}')
    in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v)

  # Checkify the body under the same axis env / abstract mesh / vma settings
  # that shard_map uses when tracing its body.
  with (jshmap._extend_axis_env(mesh, manual_axes),
        mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)),  # type: ignore[arg-type]
        config._check_vma(check_vma)):
    # jaxpr to checked_jaxpr
    checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
        pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
    )
  num_out_error_vals = out_tree.num_leaves - len(out_specs)

  # Wraps the checked body so each output error leaf gains a leading axis,
  # along which the per-shard errors are concatenated by the out specs below.
  def expand_errors_leading_dim(*xs):
    outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
    errs, outs = split_list(outs, [num_out_error_vals])
    errs = [lax.expand_dims(e, [0]) for e in errs]
    return *errs, *outs

  with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma):
    checked_jaxpr, _ = pe.trace_to_jaxpr(
        expand_errors_leading_dim, None,
        tuple(checked_jaxpr.in_avals),
        debug_info=checked_jaxpr.jaxpr.debug_info)

  # Update shard_map params to account for extra error values.
  # Use fully sharded partitioning for out errors.
  new_out_specs = (*([P(mesh.axis_names)] * num_out_error_vals), *out_specs)
  subfun = lu.hashable_partial(
      lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
      checked_jaxpr.jaxpr, checked_jaxpr.consts
  )
  new_params = dict(
      jaxpr=checked_jaxpr.jaxpr,
      in_specs=new_in_specs,
      out_specs=new_out_specs,
      **kwargs,
  )
  _, new_params = jshmap.shard_map_p.get_bind_params(new_params)

  err_and_out = jshmap.shard_map_p.bind(subfun, *new_vals_in, **new_params)
  return tree_unflatten(out_tree, err_and_out)
error_checks[jshmap.shard_map_p] = shard_map_error_check

def custom_jvp_call_rule(in_err: Error,
                         enabled_errors: set, *in_vals, num_consts,
                         jvp_jaxpr_fun: lu.WrappedFun,
                         call_jaxpr: core.ClosedJaxpr, **params):
  """Checkify rule for ``custom_jvp_call_p``.

  The primal function is checkified; the JVP rule is not (see below) and only
  forwards the error values. Operands are bound as ``[*errs, *in_vals]``.
  """
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree),
      debug_info=call_jaxpr.jaxpr.debug_info)
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)
  # `num_consts` is consumed by `lift_jvp`; it is not forwarded in `params`.
  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
  jvp, jvp_out_tree = flatten_fun_output(jvp)
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  # NOTE(review): `fst` presumably tells whether the checkified primal
  # function (rather than the JVP) populated its aux output.
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule

# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs: int, num_consts: int,
             jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
  """Wraps ``jvp_jaxpr_fun`` so it accepts and forwards checkify's error leaves.

  The returned function receives ``[*primals_in, *tangents_in]``. Each half is
  laid out as ``[*errs, *consts, *args]``, which is the order in which
  ``custom_jvp_call_rule`` binds ``*err_vals, *in_vals`` (with ``in_vals``
  starting with ``num_consts`` consts). Errors and consts are stripped before
  evaluating the user JVP. The error primals and tangents are forwarded
  unchanged to the front of the primal and tangent outputs, respectively.

  Args:
    num_errs: number of flattened error leaves at the front of each half.
    num_consts: number of consts following the error leaves in each half.
    jvp_jaxpr_fun: called with per-tangent symbolic-zero flags; returns
      ``(jvp_jaxpr, jvp_consts, out_zeros)``.

  Returns:
    A ``WrappedFun`` returning
    ``[*primal_errs, *out_primals, *tangent_errs, *out_tangents]``.
  """
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    # Each half of `xs` is [errs, consts, args]; the user JVP only takes args.
    num_skip = num_errs + num_consts
    primals, tangents = xs[num_skip:n], xs[n+num_skip:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    # Re-insert symbolic zeros for the output tangents flagged in out_zeros.
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    # The error leaves lead each half (ahead of the consts), so take them from
    # the front rather than from after the consts.
    primal_errs = xs[:num_errs]
    tangent_errs = xs[n:n+num_errs]
    return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
  return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

def custom_vjp_call_rule(in_err, enabled_errors, *in_vals,
                         call_jaxpr: core.ClosedJaxpr,
                         fwd_jaxpr_thunk, num_consts,
                         bwd: lu.WrappedFun, out_trees,
                         symbolic_zeros: bool):
  """Checkify rule for ``custom_vjp_call_p``.

  The primal function is checkified. The forward rule is not checkified (see
  the TODO below): it ignores the error inputs, and in that case the incoming
  error is returned unchanged. The backward rule returns ``None`` cotangents
  for the error inputs. Operands are bound as ``[*errs, *in_vals]``.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree),
      debug_info=call_jaxpr.jaxpr.debug_info)
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # `args` appears to interleave each operand with its zero flag. Operands
    # are ordered [errs, consts, args]: drop the errors, then the consts.
    xs, zeros = args[::2], args[1::2]
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*zeros)
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr
  checkified_fwd_wrapped = lu.wrap_init(checkified_fwd,
                                        debug_info=fwd_jaxpr_thunk.debug_info)
  # Prepend one `None` (zero) cotangent per error leaf to the bwd outputs.
  bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)),
                      debug_info=bwd.debug_info)
  checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd_wrapped,
      bwd_, *err_vals, *in_vals, out_trees=out_trees,
      symbolic_zeros=symbolic_zeros)
  # NOTE(review): `fst` presumably tells whether the checkified primal
  # function (rather than the fwd rule) populated its aux output.
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    # The fwd rule produces no error outputs; forward the input error.
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_p] = custom_vjp_call_rule


def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for ``check_p``: merges the checked error into ``error``.

  Args:
    error: the error accumulated so far.
    enabled_errors: the enabled error categories; only effects whose
      ``error_type`` is in this set are merged into ``error``.
    *args: flattened leaves of the error being checked.
    err_tree: treedef used to rebuild that error from ``args``.
    debug: unused by this rule.

  Returns:
    ``(discharged_error, [])``: the updated error and no outputs.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  # Split up new_error into error to be functionalized if it's included in
  # enabled_errors (=discharged_error) and an error to be defunctionalized if
  # it's not included (=recharged_error)
  discharged_error = error
  recharged_error = init_error
  for error_effect in new_error._pred.keys():
    pred = new_error._pred[error_effect]
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    if error_effect.error_type in enabled_errors:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  # On key collisions the metadata already in `discharged_error` wins.
  discharged_error = discharged_error._replace(
      _metadata={**new_error._metadata, **discharged_error._metadata})
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  # NOTE: `recharged_error` is currently built but not used.
  return discharged_error, []
error_checks[check_p] = check_discharge_rule


## checkify public api

# Sets of error categories accepted by `checkify(..., errors=...)`. They are
# frozensets, so they can be combined with set operations such as `|`.
user_checks = frozenset([FailedCheckError])
nan_checks = frozenset([NaNError])
index_checks = frozenset([OOBError])
div_checks = frozenset([DivisionByZeroError])
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)


def checkify(f: Callable[..., Out],
             errors: frozenset[ErrorCategory] = user_checks
             ) -> Callable[..., tuple[Error, Out]]:
  """Functionalize `check` calls in ``f``, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object ``err`` along with the
  output of the original function. ``err.get()`` will either return ``None``
  (if no error occurred) or a string containing an error message. This error
  message will correspond to the first error which occurred. ``err.throw()``
  will raise a ``JaxRuntimeError`` with the error message if an error
  occurred (see the example below).

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing in a set of error
  categories (e.g. ``errors=nan_checks``). Multiple sets can be re-combined
  (e.g. ``errors=float_checks|user_checks``).

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # Close over all arguments so they're not turned into abstract values.
    in_tree = jtu.tree_structure(())
    closed_f = lambda: f(*args, **kwargs)
    # Stage out `f` (with its arguments baked in) to a jaxpr.
    debug_info = api_util.debug_info("checkify", f, args, kwargs).with_unknown_names()
    jaxpr_, out_tree = pe.trace_to_jaxpr(closed_f, in_tree, (), debug_info)
    jaxpr, consts = pe.separate_consts(jaxpr_)
    # Checkify the staged jaxpr, starting from an empty error.
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
    return error, jtu.tree_unflatten(out_tree, out_flat)
  return checked_fun

def check(pred: Bool, msg: str,
          *fmt_args,
          debug: bool = False,
          **fmt_kwargs,
          ) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, a FailedCheckError error is added. Must be a scalar
      boolean, otherwise a ``TypeError`` is raised.
    msg: error message if error is added. Can be a format string.
    debug: Whether to turn on debugging mode. If True, the check is dropped
      unless the function is transformed by checkify.checkify. If False, the
      check must be functionalized using checkify.checkify.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, e.g.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  _check(pred, msg, debug, *fmt_args, **fmt_kwargs)

def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of :func:`check` and :func:`debug_check`.

  Validates ``pred`` and the formatting arguments, builds an error that is
  set when ``pred`` is False, and binds it through ``_check_error``.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean, or a formatting-argument
      leaf is not an array.
  """
  if not is_scalar_pred(pred):
    api_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{api_name} takes a scalar pred as argument, got {pred}')
  for leaf in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if isinstance(leaf, (Array, np.ndarray)):
      continue
    raise TypeError('Formatting arguments to checkify.check need to be '
                    'PyTrees of arrays, but got '
                    f'{leaf!r} of type {type(leaf)}.')
  failure = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
  predicate_failed = jnp.logical_not(pred)
  _check_error(assert_func(init_error, predicate_failed, failure), debug=debug)

def _check_error(error, *, debug=False):
  """Bind ``check_p`` on the flattened leaves of ``error``.

  If any predicate is non-scalar, the error is first reduced with
  ``_reduce_any_error``. Returns the result of ``check_p.bind``.
  """
  has_nonscalar_pred = any(np.shape(p) for p in error._pred.values())
  if has_nonscalar_pred:
    error = _reduce_any_error(error)
  leaves, treedef = tree_flatten(error)
  return check_p.bind(*leaves, err_tree=treedef, debug=debug)


def is_scalar_pred(pred) -> bool:
  """Return True if ``pred`` is a scalar boolean usable as a check predicate.

  Accepts Python ``bool``, NumPy ``bool_`` scalars, and 0-d arrays (JAX or
  NumPy) with a boolean dtype. Non-boolean values (e.g. ints) and non-scalar
  arrays are rejected.
  """
  if isinstance(pred, (bool, np.bool_)):
    return True
  return (isinstance(pred, (Array, np.ndarray)) and pred.shape == () and
          pred.dtype == np.dtype('bool'))


def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate when running under checkify, otherwise is a no-op.

  A `debug_check` will only be run if it is transformed by :func:`~checkify`,
  otherwise the check will be dropped.

  Args:
    pred: if False, a FailedCheckError error is added. Must be a scalar
      boolean, otherwise a ``TypeError`` is raised.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, e.g.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  # `debug=True` is passed positionally because `_check` takes the format
  # arguments as *args right after it.
  _check(pred, msg, True, *fmt_args, **fmt_kwargs)


def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check`, but with a different signature.
  :func:`~check` takes a boolean predicate and a new error message string,
  whereas this function takes an ``Error`` value. Both raise a Python
  Exception on failure (a side-effect). Therefore neither can be staged out
  by :func:`~jax.jit`, :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc., and
  both can be functionalized by using :func:`~checkify`.

  Unlike :func:`~check`, this function is a direct inverse of
  :func:`~checkify`. :func:`~checkify` takes a function which can raise a
  Python Exception and produces a function without that effect that instead
  returns an ``Error`` value. ``check_error`` accepts an ``Error`` value and
  can produce the side-effect of raising an Exception. In other words,
  :func:`~checkify` goes from functionalizable Exception effect to error
  value, and ``check_error`` goes from error value back to a functionalizable
  Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``checks`` via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    TypeError: if ``error`` is not an ``Error``.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if isinstance(error, Error):
    _check_error(error, debug=False)
  else:
    raise TypeError(
        f'check_error takes an Error as argument, got type {type(error)} instead.')
